// Copyright (c) 2017-2018 Intel Corporation
// Copyright (c) 2018 Huawei Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

package manager

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"
	"sync"

	"github.com/sirupsen/logrus"

	"github.com/kata-containers/kata-containers/src/runtime/pkg/device/api"
	"github.com/kata-containers/kata-containers/src/runtime/pkg/device/config"
	"github.com/kata-containers/kata-containers/src/runtime/pkg/device/drivers"
	"github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils"
)

var (
	// ErrIDExhausted represents that devices are too many
	// and no more IDs can be generated
	ErrIDExhausted = errors.New("IDs are exhausted")
	// ErrDeviceNotExist represents device hasn't been created before
	ErrDeviceNotExist = errors.New("device with specified ID hasn't been created")
	// ErrDeviceNotAttached represents the device isn't attached
	ErrDeviceNotAttached = errors.New("device isn't attached")
	// ErrRemoveAttachedDevice represents the device isn't detached
	// so not allow to remove from list
	ErrRemoveAttachedDevice = errors.New("can't remove attached device")
)

type deviceManager struct {
	devices map[string]api.Device

	blockDriver        string
	vhostUserStorePath string

	sync.RWMutex

	vhostUserStoreEnabled bool

	vhostUserReconnectTimeout uint32
}

func deviceLogger() *logrus.Entry {
	return api.DeviceLogger().WithField("subsystem", "deviceManager")
}

// NewDeviceManager creates a deviceManager object behaved as api.DeviceManager
func NewDeviceManager(blockDriver string, vhostUserStoreEnabled bool, vhostUserStorePath string, vhostUserReconnect uint32, devices []api.Device) api.DeviceManager {
	dm := &deviceManager{
		vhostUserStoreEnabled:     vhostUserStoreEnabled,
		vhostUserStorePath:        vhostUserStorePath,
		vhostUserReconnectTimeout: vhostUserReconnect,
		devices:                   make(map[string]api.Device),
	}
	if blockDriver == config.VirtioMmio {
		dm.blockDriver = config.VirtioMmio
	} else if blockDriver == config.VirtioBlock {
		dm.blockDriver = config.VirtioBlock
	} else if blockDriver == config.Nvdimm {
		dm.blockDriver = config.Nvdimm
	} else if blockDriver == config.VirtioBlockCCW {
		dm.blockDriver = config.VirtioBlockCCW
	} else {
		dm.blockDriver = config.VirtioSCSI
	}

	config.PCIeDevices = make(map[config.PCIePort]config.PCIePortMapping)

	config.PCIeDevices[config.RootPort] = make(map[string]bool)
	config.PCIeDevices[config.SwitchPort] = make(map[string]bool)
	config.PCIeDevices[config.BridgePort] = make(map[string]bool)

	for _, dev := range devices {
		dm.devices[dev.DeviceID()] = dev
	}
	return dm
}

func (dm *deviceManager) findDevice(devInfo *config.DeviceInfo) api.Device {
	// For devices with a major of -1, we use the host path to find existing instances.
	if devInfo.Major == -1 {
		for _, dev := range dm.devices {
			dma, _ := dev.GetMajorMinor()
			if dma == -1 && dev.GetHostPath() == devInfo.HostPath {
				return dev
			}
		}
		return nil
	}

	for _, dev := range dm.devices {
		dma, dmi := dev.GetMajorMinor()
		if dma == devInfo.Major && dmi == devInfo.Minor {
			return dev
		}
	}
	return nil
}

// createDevice creates one device based on DeviceInfo
func (dm *deviceManager) createDevice(devInfo config.DeviceInfo) (dev api.Device, err error) {
	// pmem device may points to block devices or raw files,
	// do not change its HostPath.
	if !devInfo.Pmem {
		path, err := config.GetHostPathFunc(devInfo, dm.vhostUserStoreEnabled, dm.vhostUserStorePath)
		if err != nil {
			return nil, err
		}
		devInfo.HostPath = path
	}

	defer func() {
		if err == nil {
			dev.Reference()
		}
	}()

	if existingDev := dm.findDevice(&devInfo); existingDev != nil {
		return existingDev, nil
	}

	// device ID must be generated by manager instead of device itself
	// in case of ID collision
	if devInfo.ID, err = dm.newDeviceID(); err != nil {
		return nil, err
	}
	if IsVFIODevice(devInfo.HostPath) {
		return drivers.NewVFIODevice(&devInfo), nil
	} else if IsVhostUserBlk(devInfo) {
		if devInfo.DriverOptions == nil {
			devInfo.DriverOptions = make(map[string]string)
		}
		devInfo.DriverOptions[config.BlockDriverOpt] = dm.blockDriver
		devInfo.DriverOptions[config.VhostUserReconnectTimeOutOpt] = fmt.Sprintf("%d", dm.vhostUserReconnectTimeout)
		return drivers.NewVhostUserBlkDevice(&devInfo), nil
	} else if isBlock(devInfo) {
		if devInfo.DriverOptions == nil {
			devInfo.DriverOptions = make(map[string]string)
		}
		devInfo.DriverOptions[config.BlockDriverOpt] = dm.blockDriver
		return drivers.NewBlockDevice(&devInfo), nil
	} else {
		deviceLogger().WithField("device", devInfo.HostPath).Info("Device has not been passed to the container")
		return drivers.NewGenericDevice(&devInfo), nil
	}
}

// NewDevice creates a device based on specified DeviceInfo
func (dm *deviceManager) NewDevice(devInfo config.DeviceInfo) (api.Device, error) {
	dm.Lock()
	defer dm.Unlock()
	dev, err := dm.createDevice(devInfo)
	if err == nil {
		dm.devices[dev.DeviceID()] = dev
	}
	return dev, err
}

// RemoveDevice deletes the device from list based on specified device id
func (dm *deviceManager) RemoveDevice(id string) error {
	dm.Lock()
	defer dm.Unlock()
	dev, ok := dm.devices[id]
	if !ok {
		return ErrDeviceNotExist
	}

	if dev.Dereference() == 0 {
		if dev.GetAttachCount() > 0 {
			return ErrRemoveAttachedDevice
		}
		delete(dm.devices, id)
	}
	return nil
}

func (dm *deviceManager) newDeviceID() (string, error) {
	for i := 0; i < 5; i++ {
		// generate an random ID
		randBytes, err := utils.GenerateRandomBytes(8)
		if err != nil {
			return "", err
		}
		id := hex.EncodeToString(randBytes)

		// check ID collision, choose another one if ID is in use
		if _, ok := dm.devices[id]; !ok {
			return id, nil
		}
	}
	return "", ErrIDExhausted
}

func (dm *deviceManager) AttachDevice(ctx context.Context, id string, dr api.DeviceReceiver) error {
	dm.Lock()
	defer dm.Unlock()

	dev, ok := dm.devices[id]
	if !ok {
		return ErrDeviceNotExist
	}

	if err := dev.Attach(ctx, dr); err != nil {
		return err
	}
	return nil
}

func (dm *deviceManager) DetachDevice(ctx context.Context, id string, dr api.DeviceReceiver) error {
	dm.Lock()
	defer dm.Unlock()

	d, ok := dm.devices[id]
	if !ok {
		return ErrDeviceNotExist
	}
	if d.GetAttachCount() == 0 {
		return ErrDeviceNotAttached
	}

	if err := d.Detach(ctx, dr); err != nil {
		return err
	}
	return nil
}

func (dm *deviceManager) GetDeviceByID(id string) api.Device {
	dm.RLock()
	defer dm.RUnlock()
	if d, ok := dm.devices[id]; ok {
		return d
	}
	return nil
}

func (dm *deviceManager) GetAllDevices() []api.Device {
	dm.RLock()
	defer dm.RUnlock()
	devices := []api.Device{}
	for _, v := range dm.devices {
		devices = append(devices, v)
	}
	return devices
}

func (dm *deviceManager) IsDeviceAttached(id string) bool {
	dm.RLock()
	defer dm.RUnlock()
	d, ok := dm.devices[id]
	if !ok {
		return false
	}
	return d.GetAttachCount() > 0
}

// LoadDevices load devices from persist state
func (dm *deviceManager) LoadDevices(devStates []config.DeviceState) {
	dm.Lock()
	defer dm.Unlock()

	for _, ds := range devStates {
		var dev api.Device

		switch config.DeviceType(ds.Type) {
		case config.DeviceGeneric:
			dev = &drivers.GenericDevice{}
		case config.DeviceBlock:
			dev = &drivers.BlockDevice{}
		case config.DeviceVFIO:
			dev = &drivers.VFIODevice{}
		case config.VhostUserSCSI:
			dev = &drivers.VhostUserSCSIDevice{}
		case config.VhostUserBlk:
			dev = &drivers.VhostUserBlkDevice{}
		case config.VhostUserNet:
			dev = &drivers.VhostUserNetDevice{}
		default:
			deviceLogger().WithField("device-type", ds.Type).Warning("unrecognized device type is detected")
			// continue the for loop
			continue
		}

		dev.Load(ds)
		dm.devices[dev.DeviceID()] = dev
	}
}
